{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Graph Neural Networks for Computer Vision\n",
    "\n",
    "## 1. Initial Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "from torch_geometric.nn import GCNConv, TopKPooling\n",
    "from torch.nn import MultiheadAttention\n",
    "from torch_geometric.data import Data, Batch\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Graph Construction Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_face_graph(landmarks, threshold=2.0):\n",
    "    \"\"\"Build an undirected proximity graph over facial landmarks.\n",
    "\n",
    "    Every landmark becomes a node whose ``pos`` attribute holds the landmark\n",
    "    as given; two nodes are linked when their Euclidean distance is strictly\n",
    "    below ``threshold``.\n",
    "\n",
    "    Args:\n",
    "        landmarks: Sized sequence of equal-length points, e.g. an ``(N, D)``\n",
    "            array or a list of coordinate tuples.\n",
    "        threshold: Distance cut-off, in the same units as ``landmarks``.\n",
    "\n",
    "    Returns:\n",
    "        networkx.Graph with nodes ``0 .. len(landmarks) - 1``.\n",
    "    \"\"\"\n",
    "    # Distances are computed on a float array so that plain lists/tuples work\n",
    "    # and unsigned-integer coordinates cannot wrap around on subtraction.\n",
    "    points = np.asarray(landmarks, dtype=float)\n",
    "    num_points = len(points)\n",
    "\n",
    "    G = nx.Graph()\n",
    "    for i, landmark in enumerate(landmarks):\n",
    "        G.add_node(i, pos=landmark)\n",
    "\n",
    "    # Visit each unordered pair once (j > i); the graph is undirected.\n",
    "    for i in range(num_points):\n",
    "        for j in range(i + 1, num_points):\n",
    "            if np.linalg.norm(points[i] - points[j]) < threshold:\n",
    "                G.add_edge(i, j)\n",
    "    return G\n",
    "\n",
    "\n",
    "def create_pixel_graph(image, connectivity=4):\n",
    "    \"\"\"Build a grid graph with one node per pixel.\n",
    "\n",
    "    Node ids are row-major (``row * width + col``). Each node stores the pixel\n",
    "    value in ``features`` and its ``(row, col)`` position in ``pos``; edges link\n",
    "    every pixel to its in-bounds neighbours.\n",
    "\n",
    "    Args:\n",
    "        image: Array indexed as ``image[row, col]`` whose ``shape`` starts with\n",
    "            ``(height, width)``.\n",
    "        connectivity: ``4`` (edge neighbours) or ``8`` (edge and diagonal\n",
    "            neighbours).\n",
    "\n",
    "    Returns:\n",
    "        networkx.Graph with ``height * width`` nodes.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``connectivity`` is neither 4 nor 8.\n",
    "    \"\"\"\n",
    "    # Neighbour offsets do not depend on the pixel, so build them once.\n",
    "    offsets = [(-1, 0), (1, 0), (0, -1), (0, 1)]\n",
    "    if connectivity == 8:\n",
    "        offsets += [(-1, -1), (-1, 1), (1, -1), (1, 1)]\n",
    "    elif connectivity != 4:\n",
    "        # Reject unsupported values up front instead of failing inside the loop\n",
    "        # on an undefined neighbour list.\n",
    "        raise ValueError(f\"connectivity must be 4 or 8, got {connectivity!r}\")\n",
    "\n",
    "    height, width = image.shape[:2]\n",
    "    G = nx.Graph()\n",
    "\n",
    "    for i in range(height):\n",
    "        for j in range(width):\n",
    "            node_id = i * width + j\n",
    "            G.add_node(node_id, features=image[i, j], pos=(i, j))\n",
    "\n",
    "            for di, dj in offsets:\n",
    "                ni, nj = i + di, j + dj\n",
    "                # Skip neighbours that fall outside the image.\n",
    "                if 0 <= ni < height and 0 <= nj < width:\n",
    "                    G.add_edge(node_id, ni * width + nj)\n",
    "    return G"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Basic GNN Models\n",
    "\n",
    "### 3.1 Simple GCN for Image Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleGCN(nn.Module):\n",
    "    \"\"\"Two-layer GCN that outputs one ``num_classes``-wide vector per node.\n",
    "\n",
    "    Args:\n",
    "        num_node_features: Width of each input node feature vector.\n",
    "        num_classes: Width of each output node vector.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_node_features, num_classes):\n",
    "        super().__init__()\n",
    "        hidden_channels = 16\n",
    "        self.conv1 = GCNConv(num_node_features, hidden_channels)\n",
    "        self.conv2 = GCNConv(hidden_channels, num_classes)\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        \"\"\"Return the second layer's raw output (no softmax applied).\"\"\"\n",
    "        hidden = torch.relu(self.conv1(x, edge_index))\n",
    "        return self.conv2(hidden, edge_index)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 Hierarchical GCN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class HierarchicalGCN(nn.Module):\n",
    "    \"\"\"Three GCN layers interleaved with two TopK pooling stages.\n",
    "\n",
    "    Both pooling stages are configured with ``ratio=0.8``; the result is the\n",
    "    output of the last GCN layer on the graph returned by the second pooling.\n",
    "\n",
    "    Args:\n",
    "        num_node_features: Width of each input node feature vector.\n",
    "        num_classes: Width of each output node vector.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_node_features, num_classes):\n",
    "        super().__init__()\n",
    "        self.conv1 = GCNConv(num_node_features, 64)\n",
    "        self.pool1 = TopKPooling(64, ratio=0.8)\n",
    "        self.conv2 = GCNConv(64, 32)\n",
    "        self.pool2 = TopKPooling(32, ratio=0.8)\n",
    "        self.conv3 = GCNConv(32, num_classes)\n",
    "\n",
    "    def forward(self, x, edge_index, batch):\n",
    "        \"\"\"Run conv -> pool -> conv -> pool -> conv and return the final conv output.\"\"\"\n",
    "        stages = ((self.conv1, self.pool1), (self.conv2, self.pool2))\n",
    "        for conv, pool in stages:\n",
    "            x = conv(x, edge_index)\n",
    "            # The pooling layer returns a 6-tuple; only the node features,\n",
    "            # edge index and batch vector (positions 0, 1, 3) are carried on.\n",
    "            pooled = pool(x, edge_index, None, batch)\n",
    "            x, edge_index, batch = pooled[0], pooled[1], pooled[3]\n",
    "        return self.conv3(x, edge_index)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Object Detection and Segmentation Models\n",
    "\n",
    "### 4.1 Object Proposal and Instance Segmentation GNNs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ObjectProposalGNN(nn.Module):\n",
    "    \"\"\"Three-layer GCN that emits a single raw score per node.\n",
    "\n",
    "    Args:\n",
    "        num_node_features: Width of each input node feature vector.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_node_features):\n",
    "        super().__init__()\n",
    "        self.conv1 = GCNConv(num_node_features, 64)\n",
    "        self.conv2 = GCNConv(64, 32)\n",
    "        self.conv3 = GCNConv(32, 1)\n",
    "\n",
    "    def forward(self, x, edge_index, batch):\n",
    "        \"\"\"Return per-node scores; ``batch`` is accepted but not used.\"\"\"\n",
    "        hidden = self.conv1(x, edge_index).relu()\n",
    "        hidden = self.conv2(hidden, edge_index).relu()\n",
    "        return self.conv3(hidden, edge_index)\n",
    "\n",
    "class InstanceSegmentationGNN(nn.Module):\n",
    "    \"\"\"Three-layer GCN whose single per-node output is passed through a sigmoid.\n",
    "\n",
    "    Args:\n",
    "        num_features: Width of each input node feature vector.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_features):\n",
    "        super().__init__()\n",
    "        self.conv1 = GCNConv(num_features, 64)\n",
    "        self.conv2 = GCNConv(64, 32)\n",
    "        self.conv3 = GCNConv(32, 1)\n",
    "\n",
    "    def forward(self, x, edge_index, batch):\n",
    "        \"\"\"Return a per-node value in (0, 1); ``batch`` is accepted but not used.\"\"\"\n",
    "        hidden = x\n",
    "        for conv in (self.conv1, self.conv2):\n",
    "            hidden = torch.relu(conv(hidden, edge_index))\n",
    "        mask_logits = self.conv3(hidden, edge_index)\n",
    "        return torch.sigmoid(mask_logits)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Multimodal GNN Models\n",
    "\n",
    "### 5.1 Visual-Textual and Cross-Modal Retrieval GNNs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VisualTextualGNN(nn.Module):\n",
    "    \"\"\"Encodes image and text node features separately, sums them, and scores each node.\n",
    "\n",
    "    Both modalities are propagated over the same ``edge_index`` and added\n",
    "    element-wise, so they must describe the same set of nodes.\n",
    "\n",
    "    Args:\n",
    "        image_feature_dim: Width of each image node feature vector.\n",
    "        word_embedding_dim: Width of each text node feature vector.\n",
    "        hidden_dim: Width of the shared hidden representation.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, image_feature_dim, word_embedding_dim, hidden_dim):\n",
    "        super().__init__()\n",
    "        self.image_encoder = GCNConv(image_feature_dim, hidden_dim)\n",
    "        self.text_encoder = GCNConv(word_embedding_dim, hidden_dim)\n",
    "        self.fusion_layer = GCNConv(hidden_dim, hidden_dim)\n",
    "        self.output_layer = nn.Linear(hidden_dim, 1)\n",
    "\n",
    "    def forward(self, image_features, word_embeddings, edge_index):\n",
    "        \"\"\"Return one scalar per node computed from the fused representation.\"\"\"\n",
    "        visual = self.image_encoder(image_features, edge_index)\n",
    "        textual = self.text_encoder(word_embeddings, edge_index)\n",
    "        summed = visual + textual\n",
    "        fused = self.fusion_layer(summed, edge_index)\n",
    "        return self.output_layer(fused)\n",
    "\n",
    "class CrossModalRetrievalGNN(nn.Module):\n",
    "    \"\"\"Produces a joint ``hidden_dim``-wide embedding per node from image and text features.\n",
    "\n",
    "    Args:\n",
    "        image_dim: Width of each image node feature vector.\n",
    "        text_dim: Width of each text node feature vector.\n",
    "        hidden_dim: Width of the joint embedding.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, image_dim, text_dim, hidden_dim):\n",
    "        super().__init__()\n",
    "        self.image_encoder = GCNConv(image_dim, hidden_dim)\n",
    "        self.text_encoder = GCNConv(text_dim, hidden_dim)\n",
    "        self.fusion = GCNConv(hidden_dim, hidden_dim)\n",
    "\n",
    "    def forward(self, image_features, text_features, edge_index):\n",
    "        \"\"\"Return the fused embedding; both modalities share ``edge_index``.\"\"\"\n",
    "        image_part = self.image_encoder(image_features, edge_index)\n",
    "        text_part = self.text_encoder(text_features, edge_index)\n",
    "        # Element-wise sum merges the two modalities before the final layer.\n",
    "        return self.fusion(image_part + text_part, edge_index)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Advanced Vision Models\n",
    "\n",
    "### 6.1 Relational Object Detection and Panoptic Segmentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RelationalObjectDetectionGNN(nn.Module):\n",
    "    \"\"\"Two GCN layers feeding a class head and a 4-value box head per node.\n",
    "\n",
    "    Args:\n",
    "        num_features: Width of each input node feature vector.\n",
    "        num_classes: Number of class scores produced per node.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_features, num_classes):\n",
    "        super().__init__()\n",
    "        self.conv1 = GCNConv(num_features, 64)\n",
    "        self.conv2 = GCNConv(64, 32)\n",
    "        self.classifier = nn.Linear(32, num_classes)\n",
    "        self.bbox_regressor = nn.Linear(32, 4)\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        \"\"\"Return ``(class_scores, bbox_refinement)`` from shared node embeddings.\"\"\"\n",
    "        embedding = x\n",
    "        for conv in (self.conv1, self.conv2):\n",
    "            embedding = torch.relu(conv(embedding, edge_index))\n",
    "        class_scores = self.classifier(embedding)\n",
    "        bbox_refinement = self.bbox_regressor(embedding)\n",
    "        return class_scores, bbox_refinement\n",
    "\n",
    "class PanopticSegmentationGNN(nn.Module):\n",
    "    \"\"\"Two GCN layers feeding a semantic head and a single-value instance head per node.\n",
    "\n",
    "    Args:\n",
    "        num_features: Width of each input node feature vector.\n",
    "        num_classes: Number of semantic scores produced per node.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_features, num_classes):\n",
    "        super().__init__()\n",
    "        self.conv1 = GCNConv(num_features, 64)\n",
    "        self.conv2 = GCNConv(64, 32)\n",
    "        self.classifier = nn.Linear(32, num_classes)\n",
    "        self.instance_predictor = nn.Linear(32, 1)\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        \"\"\"Return ``(semantic_pred, instance_pred)`` from shared node embeddings.\"\"\"\n",
    "        hidden = self.conv1(x, edge_index).relu()\n",
    "        hidden = self.conv2(hidden, edge_index).relu()\n",
    "        return self.classifier(hidden), self.instance_predictor(hidden)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Navigation and Hierarchical Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VisualLanguageNavigationGNN(nn.Module):\n",
    "    \"\"\"Scores navigation actions from a scene graph and an instruction graph.\n",
    "\n",
    "    Each input is encoded by its own GCN layer and the two encodings are\n",
    "    concatenated along the last dimension, so both must have the same number\n",
    "    of rows.\n",
    "\n",
    "    Args:\n",
    "        visual_dim: Width of each visual node feature vector.\n",
    "        instruction_dim: Width of each instruction node feature vector.\n",
    "        hidden_dim: Output width of each encoder.\n",
    "        num_actions: Number of action logits produced per row.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, visual_dim, instruction_dim, hidden_dim, num_actions=4):\n",
    "        super().__init__()\n",
    "        self.visual_gnn = GCNConv(visual_dim, hidden_dim)\n",
    "        self.instruction_gnn = GCNConv(instruction_dim, hidden_dim)\n",
    "        # The head consumes both encodings side by side, hence 2 * hidden_dim.\n",
    "        self.navigation_head = nn.Linear(2 * hidden_dim, num_actions)\n",
    "\n",
    "    def forward(self, visual_obs, instructions, scene_graph, instr_graph):\n",
    "        \"\"\"Return action logits; ``scene_graph``/``instr_graph`` are the edge indices.\"\"\"\n",
    "        encodings = (\n",
    "            self.visual_gnn(visual_obs, scene_graph),\n",
    "            self.instruction_gnn(instructions, instr_graph),\n",
    "        )\n",
    "        return self.navigation_head(torch.cat(encodings, dim=-1))\n",
    "\n",
    "class HierarchicalImageGNN(nn.Module):\n",
    "    \"\"\"Stack of (GCNConv, TopKPooling) levels that returns features from every level.\n",
    "\n",
    "    Args:\n",
    "        input_dim: Width of each input node feature vector.\n",
    "        hidden_dims: Output width of each level, in order. Any iterable of ints\n",
    "            is accepted; one level is created per entry.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, hidden_dims=(64, 32, 16)):\n",
    "        super().__init__()\n",
    "        # Immutable default plus a local copy: avoids a shared mutable default\n",
    "        # and lets callers pass any iterable (tuple, generator, ...).\n",
    "        hidden_dims = list(hidden_dims)\n",
    "        self.levels = len(hidden_dims)\n",
    "        self.gnns = nn.ModuleList()\n",
    "        self.pools = nn.ModuleList()\n",
    "\n",
    "        curr_dim = input_dim\n",
    "        for hidden_dim in hidden_dims:\n",
    "            self.gnns.append(GCNConv(curr_dim, hidden_dim))\n",
    "            self.pools.append(TopKPooling(hidden_dim, ratio=0.5))\n",
    "            # The next level consumes this level's output width.\n",
    "            curr_dim = hidden_dim\n",
    "\n",
    "    def forward(self, x, edge_index, batch):\n",
    "        \"\"\"Return a list holding the pooled node features of each level, first level first.\"\"\"\n",
    "        features = []\n",
    "        for gnn, pool in zip(self.gnns, self.pools):\n",
    "            x = gnn(x, edge_index)\n",
    "            # Carry the pooled graph (features, edges, batch vector) to the next level.\n",
    "            x, edge_index, _, batch, _, _ = pool(x, edge_index, None, batch)\n",
    "            features.append(x)\n",
    "        return features\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Testing and Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Face graph created successfully ✓\n",
      "Pixel graph created successfully ✓\n",
      "Testing models...\n",
      "SimpleGCN: Success ✓\n",
      "HierarchicalGCN: Success ✓\n",
      "ObjectProposalGNN: Success ✓\n",
      "InstanceSegmentationGNN: Success ✓\n",
      "VisualTextualGNN: Success ✓\n",
      "RelationalObjectDetectionGNN: Success ✓\n",
      "PanopticSegmentationGNN: Success ✓\n",
      "CrossModalRetrievalGNN: Success ✓\n",
      "VisualLanguageNavigationGNN: Success ✓\n",
      "HierarchicalImageGNN: Success ✓\n"
     ]
    }
   ],
   "source": [
    "def test_models(seed=0):\n",
    "    \"\"\"Smoke-test every model with one forward pass on a small random graph.\n",
    "\n",
    "    Args:\n",
    "        seed: Value passed to ``torch.manual_seed`` so the random graph, the\n",
    "            features and the weight initialisation are reproducible. Pass\n",
    "            ``None`` to leave the global RNG untouched.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each model name to ``True`` if its forward pass completed\n",
    "        and ``False`` if it raised, so callers can assert on the outcome\n",
    "        instead of reading the printed log.\n",
    "    \"\"\"\n",
    "    if seed is not None:\n",
    "        torch.manual_seed(seed)\n",
    "\n",
    "    # Create sample data\n",
    "    num_nodes = 10\n",
    "    num_features = 3\n",
    "    num_classes = 2\n",
    "    hidden_dim = 16\n",
    "    edge_index = torch.randint(0, num_nodes, (2, 20))\n",
    "    x = torch.randn(num_nodes, num_features)\n",
    "    batch = torch.zeros(num_nodes, dtype=torch.long)  # all nodes belong to graph 0\n",
    "\n",
    "    # Each entry pairs a model with the positional arguments of its forward(),\n",
    "    # which keeps the call signature next to the model it belongs to.\n",
    "    cases = {\n",
    "        \"SimpleGCN\": (SimpleGCN(num_features, num_classes), (x, edge_index)),\n",
    "        \"HierarchicalGCN\": (HierarchicalGCN(num_features, num_classes), (x, edge_index, batch)),\n",
    "        \"ObjectProposalGNN\": (ObjectProposalGNN(num_features), (x, edge_index, batch)),\n",
    "        \"InstanceSegmentationGNN\": (InstanceSegmentationGNN(num_features), (x, edge_index, batch)),\n",
    "        \"VisualTextualGNN\": (VisualTextualGNN(num_features, num_features, hidden_dim), (x, x, edge_index)),\n",
    "        \"RelationalObjectDetectionGNN\": (RelationalObjectDetectionGNN(num_features, num_classes), (x, edge_index)),\n",
    "        \"PanopticSegmentationGNN\": (PanopticSegmentationGNN(num_features, num_classes), (x, edge_index)),\n",
    "        \"CrossModalRetrievalGNN\": (CrossModalRetrievalGNN(num_features, num_features, hidden_dim), (x, x, edge_index)),\n",
    "        \"VisualLanguageNavigationGNN\": (\n",
    "            VisualLanguageNavigationGNN(num_features, num_features, hidden_dim),\n",
    "            (x, x, edge_index, edge_index),\n",
    "        ),\n",
    "        \"HierarchicalImageGNN\": (HierarchicalImageGNN(num_features), (x, edge_index, batch)),\n",
    "    }\n",
    "\n",
    "    print(\"Testing models...\")\n",
    "    results = {}\n",
    "    with torch.no_grad():  # forward passes only; no autograd graph is needed\n",
    "        for name, (model, args) in cases.items():\n",
    "            try:\n",
    "                model(*args)\n",
    "                results[name] = True\n",
    "                print(f\"{name}: Success ✓\")\n",
    "            except Exception as e:\n",
    "                # Keep going so one broken model does not hide the others;\n",
    "                # the failure is still recorded in the returned dict.\n",
    "                results[name] = False\n",
    "                print(f\"{name}: Failed ✗ - {str(e)}\")\n",
    "    return results\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Test graph construction\n",
    "    # NOTE: points drawn from the unit square are at most sqrt(2) apart, so with\n",
    "    # the default threshold of 2.0 the resulting face graph is fully connected.\n",
    "    landmarks = np.random.rand(5, 2)\n",
    "    face_graph = create_face_graph(landmarks)\n",
    "    print(\"\\nFace graph created successfully ✓\")\n",
    "    \n",
    "    # Load and process sample image\n",
    "    try:\n",
    "        # The context manager closes the file handle once the pixels are copied\n",
    "        # into the NumPy array.\n",
    "        with Image.open('street.jpg') as sample_image:\n",
    "            image = np.array(sample_image.resize((10, 10)))\n",
    "    except FileNotFoundError:\n",
    "        # Keep the notebook runnable without the asset by substituting a random\n",
    "        # 10x10 RGB image of the same size.\n",
    "        print(\"street.jpg not found - using a random 10x10 RGB image instead\")\n",
    "        image = np.random.randint(0, 256, size=(10, 10, 3), dtype=np.uint8)\n",
    "    pixel_graph = create_pixel_graph(image)\n",
    "    print(\"Pixel graph created successfully ✓\")\n",
    "    \n",
    "    # Test all GNN models\n",
    "    test_models()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
